import itertools

import torch

from quant_model import ResNet18
from train_fun import load_dataset

def load_model_weights(model, file_path):
    """Load a saved ``state_dict`` from *file_path* into *model* and return *model*.

    The checkpoint tensors are deserialized onto the CPU (``map_location='cpu'``)
    and then copied into the model's existing parameters/buffers by
    ``load_state_dict`` (default strict key matching). The same model object is
    returned so the call can be chained.
    """
    model.load_state_dict(torch.load(file_path, map_location='cpu'))
    return model

def _iter_samples(testloader):
    """Yield ``(image, label)`` pairs one at a time from a batched loader.

    Each item of *testloader* is unpacked as ``(images, labels)`` and split
    along its first dimension. As a generator, the next batch is only requested
    once every sample of the current batch has been consumed.
    """
    for images, labels in testloader:
        yield from zip(images, labels)


def test_inference(model, testloader, device, num_images=None):
    """Run single-image inference over *testloader*, print and return the accuracy.

    Args:
        model: Classifier; the predicted class is the argmax of its output along
            dim 1. It is switched to eval mode before inference.
        testloader: Iterable yielding ``(images, labels)`` batches.
        device: Device each (single-image) input and label is moved to.
        num_images: Maximum number of images to evaluate; ``None`` evaluates
            every sample the loader yields.

    Returns:
        The accuracy in percent, or ``None`` if no image was evaluated (empty
        loader or a non-positive ``num_images``), in which case accuracy is
        undefined.
    """
    model.eval()  # switch to evaluation mode

    samples = _iter_samples(testloader)
    if num_images is not None:
        # islice stops as soon as the cap is reached without requesting another
        # batch, so the rest of the loader is not iterated for nothing.
        samples = itertools.islice(samples, max(num_images, 0))

    correct = 0
    total = 0
    with torch.no_grad():
        for image, label in samples:
            # Add a leading batch dimension of 1: each image is inferred alone.
            image = image.unsqueeze(0).to(device)
            label = label.unsqueeze(0).to(device)
            predicted = model(image).argmax(dim=1)
            correct += (predicted == label).sum().item()
            total += 1

    if total == 0:
        print('No test images were evaluated; accuracy is undefined.')
        return None

    accuracy = 100 * correct / total
    print(f'Accuracy of the network on the test images: {accuracy:.2f}%')
    return accuracy


# Despite its ``test_`` prefix this is an evaluation helper, not a pytest test
# case; keep pytest from collecting it (it would error on missing fixtures).
test_inference.__test__ = False

if __name__ == "__main__":
    # Script entry point: build ResNet18, load the checkpoint for `epoch`,
    # apply torch.quantization.convert, and report accuracy on the test loader.
    epoch = 4  # Checkpoint epoch to load; change as needed.
    file_path = f'../model_params/witin/fixed/ResNet18_param_{epoch}_quan.pth'
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ResNet18().to(device)
    # NOTE(review): 'fbgemm' is PyTorch's x86 CPU quantization backend and eager-mode
    # quantized kernels do not run on CUDA. If convert() below actually swaps in
    # quantized modules, running with a CUDA `device` will likely fail — confirm.
    model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    # NOTE(review): the '_quan' checkpoint is loaded into the model *before* convert(),
    # i.e. into the pre-convert module structure — confirm the saved keys match it.
    model = load_model_weights(model, file_path)
    # NOTE(review): no torch.quantization.prepare() call precedes convert(). Eager-mode
    # convert() only swaps submodules that carry a qconfig (normally propagated by
    # prepare()), so unless quant_model.ResNet18 sets that up itself this may be a
    # no-op — verify against quant_model.
    torch.quantization.convert(model, inplace=True)
    _, testloader = load_dataset()
    
    num_images = 100  # Set to None to evaluate every test sample, or to the number of images to evaluate.
    test_inference(model, testloader, device, num_images=num_images)